Classification in Machine Learning

Classification is a supervised machine learning method in which the model tries to predict the correct label for given input data. The model is first trained on labeled training data, then evaluated on test data, and finally used to make predictions on new, unseen data.

For example, a classification model might be trained on a dataset of images labeled as either dogs or cats; it can then predict the class of new, unseen images based on features such as color, texture, and shape.

The main goal of a classification algorithm is to identify the category of a given input, so these algorithms are used to predict categorical outputs. Unlike regression, the output is not a continuous value but a category: the input variables (x) predict a discrete output (y) representing a specific class. For example, a model trained on email datasets labeled "spam" or "not spam" predicts whether new emails are spam based on features like word frequency, the presence of specific keywords, and the subject line. An algorithm that implements classification on a dataset is known as a classifier.

Types of classification:

  1. Binary classification: The input is classified into one of two possible categories. For example, determining whether a credit card transaction is fraudulent or legitimate, or whether a person is diabetic or not.

  2. Multiclass classification: The input is classified into one of several classes, e.g., product categorization, plant species identification, and handwriting recognition.

Image 1: Sample data distribution to show different types of classification

Why not use Linear Regression for Classification?

Linear regression is designed to predict continuous values, not categories. While it's tempting to use it for classification by applying a threshold (e.g., predict class 1 if output > 0.5), this approach quickly falls apart.

The problem is that linear regression outputs are unbounded — they can go below 0 or above 1 — making them unsuitable for interpreting as probabilities. It’s also sensitive to outliers: a single extreme point can distort predictions and shift the decision threshold unpredictably.

Image 2: Sample data distribution to show sensitivity of linear regression towards outliers

This is where logistic regression comes in. It models the probability of a class using the sigmoid function, which naturally squashes outputs to the [0, 1] range. By turning linear combinations of input features into probabilities, logistic regression offers a more stable and interpretable foundation for classification.

Image 3: A possible hypothesis learnt by a classifier like logistic regression

Introduction to Logistic Regression

Logistic regression has its roots in statistics, dating back to the early 20th century. It was formally introduced by statistician David Cox in 1958, originally to model binary outcomes such as survival or death. Over time, it became a foundational tool not just in epidemiology and economics, but also in machine learning, where it is now considered a standard baseline for binary and multiclass classification tasks.

Logistic regression uses a logistic function, called the sigmoid function, to map predictions to probabilities. This makes it well suited for binary and multiclass classification tasks where the model must decide between discrete categories.

Binary Classification in LR

Image 4: A visual representation of binary classification

Definition:
Binary classification involves predicting one of two classes (e.g., 0 or 1, “no” or “yes”). The model estimates the probability \(P(Y = 1 \mid X)\), and a typical decision rule is:

predict 1 if \(\sigma(z) \geq 0.5\); otherwise, predict 0.
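This decision rule is a one-liner in code. A minimal sketch (the 0.5 threshold is the conventional default, not the only choice):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1 + np.exp(-z))

def predict_label(z, threshold=0.5):
    # Predict 1 when the estimated probability reaches the threshold.
    return int(sigmoid(z) >= threshold)
```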

Real-World Examples: spam detection (spam vs. not spam), credit card fraud detection (fraudulent vs. legitimate), and medical screening (diabetic vs. not diabetic).

Mechanics of Logistic Regression

Logistic regression models the log-odds (logit) of the outcome as a linear function of the predictors:

\(\log\frac{p}{1 - p} = w^{T}x + b\)

This duality between a linear formulation in the predictor space and a non-linear probability output is a unique and powerful feature of logistic regression.

The sigmoid function refers to an S-shaped curve that converts any real value to a range between 0 and 1. It has a non-negative derivative at each point and exactly one inflection point.

\(\sigma(z) = \frac{1}{1 + e^{- z}}\)

Where \(z = w^{T}x + b\) is the linear combination of the input features \(x\), the model weights \(w\), and a bias term \(b\).

In code:

import numpy as np

def sigmoid(z):
    return 1.0 / (1 + np.exp(-z))

Image 5: A visual representation of fitting a sigmoid function

To understand the mathematical derivation further, check out GeeksforGeeks’ blog on logistic regression!

What Is a Cost Function for a Logistic Classifier?

A Cost Function is a mathematical tool used to measure how well a machine learning model is performing. Specifically, it calculates the error — the difference between the predicted outputs and the actual outcomes. In simple terms, it tells us how wrong the model is.

For logistic regression, the cost function we use is called Cross-Entropy Loss, also known as Log Loss. Unlike squared error (used in linear regression), cross-entropy is better suited for classification tasks because it evaluates the quality of predicted probabilities:

\(J(w) = - \frac{1}{m}\sum_{i=1}^{m}\left\lbrack y^{(i)}\log\widehat{y}^{(i)} + (1 - y^{(i)})\log(1 - \widehat{y}^{(i)}) \right\rbrack\)

Where:

\(m\) is the number of training examples,
\(y^{(i)}\) is the actual label of the \(i\)-th example,
\(\widehat{y}^{(i)}\) is the predicted probability for the \(i\)-th example.

Intuition Behind the Formula:

The cost is low when the predicted probability \(\widehat{y}\) is close to the actual label \(y\), and it increases sharply as the prediction moves away from the truth. This behavior reflects the principle of likelihood maximization: minimizing the cost function pushes the model to assign higher probabilities to the correct labels.
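This intuition is easy to verify numerically. A small NumPy sketch of the cross-entropy loss (the example probabilities are illustrative):

```python
import numpy as np

def log_loss(y_true, y_pred, eps=1e-12):
    # Clip predictions away from 0 and 1 to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Confident, correct predictions give a low cost...
good = log_loss(np.array([1, 0, 1]), np.array([0.9, 0.1, 0.8]))
# ...while confident, wrong predictions are penalized heavily.
bad = log_loss(np.array([1, 0, 1]), np.array([0.1, 0.9, 0.2]))
```

Note the sharp asymmetry: being confidently wrong costs far more than being mildly uncertain, which is exactly the pressure that drives the model toward well-calibrated probabilities.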

Assumptions of Logistic Regression

Logistic regression is a powerful classification technique, but it relies on several key assumptions that ensure the validity and interpretability of the model.

1. Binary Dependent Variable

The response variable should be binary, taking on only two possible outcomes (e.g., yes/no, success/failure). For more than two categories, multinomial or ordinal logistic regression is more appropriate.

2. No Multicollinearity

Independent variables should not be highly correlated with one another. High multicollinearity can distort coefficient estimates and reduce model reliability. This can be assessed using the Variance Inflation Factor (VIF).
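The VIF can be computed with plain least squares: regress each feature on the others and take \(1/(1 - R^2)\). A minimal NumPy sketch on synthetic data (the features here are illustrative):

```python
import numpy as np

def vif(X, i):
    # VIF of column i: regress it on the remaining columns via least squares.
    y = X[:, i]
    others = np.delete(X, i, axis=1)
    A = np.column_stack([np.ones(len(X)), others])  # intercept + other features
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    resid = y - A @ coef
    r2 = 1 - resid.var() / y.var()
    return 1.0 / (1.0 - r2)

rng = np.random.default_rng(0)
x1 = rng.normal(size=200)
x2 = 0.95 * x1 + rng.normal(scale=0.1, size=200)  # nearly duplicates x1
x3 = rng.normal(size=200)                         # independent feature
X = np.column_stack([x1, x2, x3])
# x1 and x2 should show a high VIF; x3 should stay near 1.
```

A common rule of thumb flags VIF values above 5 or 10 as problematic.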

3. Linearity of Log-Odds

There should be a linear relationship between the independent variables and the log-odds of the outcome. While the outcome itself is not linear, the logit transformation assumes linearity with predictors.

4. Adequate Sample Size

Logistic regression requires a reasonably large sample size, particularly when dealing with imbalanced classes. A common guideline is at least 10 observations per predictor for the least frequent outcome.

5. No Extreme Outliers

Outliers in the independent variables can unduly influence the model. They should be identified (e.g., using Cook’s distance) and handled appropriately, either by removal, transformation, or careful interpretation.

6. Independence of Observations

Each observation should be independent of the others. Violations of this assumption, such as repeated measurements or clustered data, require alternative modeling approaches like mixed-effects models.

Multiclass Logistic Regression

Multiclass logistic regression extends the binary version to handle classification problems involving three or more discrete classes. Unlike binary logistic regression, which uses the sigmoid function to map predictions to probabilities between 0 and 1, multiclass logistic regression typically uses the softmax function, which generalizes the sigmoid to multiple classes and ensures that the sum of the predicted probabilities across all classes is 1.

Suppose you're building a model to classify types of fruit based on features like color, weight, and texture. The output classes might be: apple, banana, and orange. A multiclass logistic regression model learns to assign a probability to each of the three classes for any new input and classifies the input as the one with the highest probability.

Image 6: A visual representation of multiclass logistic regression classifier

Approaches to Multiclass Classification

There are two main strategies to implement multiclass classification using logistic regression:

  1. One-vs-Rest (OvR): Also known as One-vs-All, this strategy involves training a separate binary classifier for each class. Each model learns to distinguish one class from all the others. For instance, in a digit classification task (0–9), ten classifiers are trained, each responsible for detecting one digit versus the rest. During prediction, the class whose classifier outputs the highest probability is selected. OvR is easy to implement and works well in many scenarios but may become less reliable if the dataset is imbalanced across classes.

  2. Multinomial (Softmax) Logistic Regression: This approach directly models the probabilities of all classes in one go using the softmax function. It calculates the probability of a class by taking the exponent of that class’s linear combination of features, divided by the sum of exponentials for all classes. This ensures all predicted probabilities are between 0 and 1 and sum to 1. While more mathematically elegant and often preferred for large multiclass problems, it can be slightly more computationally intensive than OvR.

\(\text{softmax}(z_{i}) = \frac{e^{z_{i}}}{\sum_{j=1}^{K} e^{z_{j}}}\)

Here, K represents the number of elements in the vector \(z\), and \(i,\ j\) index the elements of the vector.
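The softmax is only a few lines of NumPy. A sketch, using made-up scores for the fruit example above:

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; the probabilities are unchanged.
    e = np.exp(z - np.max(z))
    return e / e.sum()

scores = np.array([2.0, 1.0, 0.1])  # hypothetical linear scores for apple/banana/orange
probs = softmax(scores)
# probs sums to 1, and the largest score receives the highest probability.
```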

Beyond Basic Classification

Multilabel Classification
Each input can belong to multiple classes. Example: recognizing faces in a photo — with three known people (say Alice, Bob, and Charlie), the output might be [1, 0, 1], meaning Alice and Charlie were detected.

Multioutput Classification
Each input has multiple outputs, each being a multiclass label. For example, predicting both the weather condition (sunny/rainy/snowy) and temperature range (hot/warm/cold).

Implementing Logistic Regression

Logistic regression is one of the most widely used classification algorithms in machine learning. To better understand how it works, let’s look at two implementations: one from scratch (using only NumPy) and one using Scikit-learn.

1. Logistic Regression From Scratch (Binary Classification)

Here, we build a simple logistic regression model using the sigmoid function and gradient descent.

Image 7: Binary LR from scratch
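Since the referenced image is not reproduced here, the following is a minimal sketch of what such a from-scratch implementation typically looks like, trained on synthetic 1-D data with assumed hyperparameters (learning rate 0.1, 1000 epochs):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1 + np.exp(-z))

def fit_logistic(X, y, lr=0.1, epochs=1000):
    # Batch gradient descent on the cross-entropy loss.
    w = np.zeros(X.shape[1])
    b = 0.0
    m = len(y)
    for _ in range(epochs):
        y_hat = sigmoid(X @ w + b)
        dw = X.T @ (y_hat - y) / m   # gradient w.r.t. weights
        db = (y_hat - y).mean()      # gradient w.r.t. bias
        w -= lr * dw
        b -= lr * db
    return w, b

def predict(X, w, b):
    return (sigmoid(X @ w + b) >= 0.5).astype(int)

# Toy problem: class 1 whenever the single feature is positive
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 1))
y = (X[:, 0] > 0).astype(int)
w, b = fit_logistic(X, y)
acc = (predict(X, w, b) == y).mean()
```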

2. Logistic Regression Using Scikit-learn

Using Scikit-learn greatly simplifies training and evaluation:

Image 8: Scikit Learn’s implementation of Logistic Regression
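Again, as the image is not reproduced here, a minimal sketch of the Scikit-learn workflow on the same kind of toy data (default `LogisticRegression` settings assumed):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Toy problem: class 1 whenever the single feature is positive
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 1))
y = (X[:, 0] > 0).astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
acc = accuracy_score(y_test, clf.predict(X_test))
```

The library handles the optimization, regularization, and multiclass extension (via the same softmax machinery described above) behind a uniform fit/predict interface.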

Logistic regression might sound fancy, but at its core, it’s a simple, powerful way to help machines make decisions. Whether you’re just starting out or building real-world models, it’s a great first step into the world of classification.

If you're interested in exploring beyond the basics, here are some curated resources: